{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/antonio/anaconda3/envs/geometric_new/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.11.0\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "os.environ['TORCH'] = torch.__version__\n",
    "print(torch.__version__)\n",
    "\n",
    "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
    "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
    "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch_geometric\n",
    "from torch_geometric.datasets import Planetoid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_cuda_if_available = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph\n",
      "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index\n",
      "Processing...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "dataset = Planetoid(root=\"tutorial1\",name= \"Cora\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Dataset properties"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cora()\n",
      "number of graphs:\t\t 1\n",
      "number of classes:\t\t 7\n",
      "number of node features:\t 1433\n",
      "number of edge features:\t 0\n"
     ]
    }
   ],
   "source": [
    "print(dataset)\n",
    "print(\"number of graphs:\\t\\t\",len(dataset))\n",
    "print(\"number of classes:\\t\\t\",dataset.num_classes)\n",
    "print(\"number of node features:\\t\",dataset.num_node_features)\n",
    "print(\"number of edge features:\\t\",dataset.num_edge_features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Dataset shapes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])\n"
     ]
    }
   ],
   "source": [
    "print(dataset.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "edge_index:\t\t torch.Size([2, 10556])\n",
      "tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],\n",
      "        [ 633, 1862, 2582,  ...,  598, 1473, 2706]])\n",
      "\n",
      "\n",
      "train_mask:\t\t torch.Size([2708])\n",
      "tensor([ True,  True,  True,  ..., False, False, False])\n",
      "\n",
      "\n",
      "x:\t\t torch.Size([2708, 1433])\n",
      "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]])\n",
      "\n",
      "\n",
      "y:\t\t torch.Size([2708])\n",
      "tensor([3, 4, 4,  ..., 3, 3, 3])\n"
     ]
    }
   ],
   "source": [
    "print(\"edge_index:\\t\\t\",dataset.data.edge_index.shape)\n",
    "print(dataset.data.edge_index)\n",
    "print(\"\\n\")\n",
    "print(\"train_mask:\\t\\t\",dataset.data.train_mask.shape)\n",
    "print(dataset.data.train_mask)\n",
    "print(\"\\n\")\n",
    "print(\"x:\\t\\t\",dataset.data.x.shape)\n",
    "print(dataset.data.x)\n",
    "print(\"\\n\")\n",
    "print(\"y:\\t\\t\",dataset.data.y.shape)\n",
    "print(dataset.data.y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import SAGEConv\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        \n",
    "        self.conv = SAGEConv(dataset.num_features,\n",
    "                             dataset.num_classes,\n",
    "                             aggr=\"max\") # max, mean, add ...)\n",
    "\n",
    "    def forward(self):\n",
    "        x = self.conv(data.x, data.edge_index)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() and use_cuda_if_available else 'cpu')\n",
    "model, data = Net().to(device), data.to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()\n",
    "    optimizer.step()\n",
    "\n",
    "\n",
    "def test():\n",
    "    model.eval()\n",
    "    logits, accs = model(), []\n",
    "    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n",
    "        pred = logits[mask].max(1)[1]\n",
    "        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n",
    "        accs.append(acc)\n",
    "    return accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 010, Val: 0.7260, Test: 0.7140\n",
      "Epoch: 020, Val: 0.7260, Test: 0.7140\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [17]\u001b[0m, in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m100\u001b[39m):\n\u001b[1;32m      3\u001b[0m     train()\n\u001b[0;32m----> 4\u001b[0m     _, val_acc, tmp_test_acc \u001b[38;5;241m=\u001b[39m \u001b[43mtest\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m      5\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m val_acc \u001b[38;5;241m>\u001b[39m best_val_acc:\n\u001b[1;32m      6\u001b[0m         best_val_acc \u001b[38;5;241m=\u001b[39m val_acc\n",
      "Input \u001b[0;32mIn [16]\u001b[0m, in \u001b[0;36mtest\u001b[0;34m()\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtest\u001b[39m():\n\u001b[1;32m      9\u001b[0m     model\u001b[38;5;241m.\u001b[39meval()\n\u001b[0;32m---> 10\u001b[0m     logits, accs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m, []\n\u001b[1;32m     11\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m _, mask \u001b[38;5;129;01min\u001b[39;00m data(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_mask\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_mask\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtest_mask\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[1;32m     12\u001b[0m         pred \u001b[38;5;241m=\u001b[39m logits[mask]\u001b[38;5;241m.\u001b[39mmax(\u001b[38;5;241m1\u001b[39m)[\u001b[38;5;241m1\u001b[39m]\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[0;32mIn [13]\u001b[0m, in \u001b[0;36mNet.forward\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m---> 10\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     11\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mlog_softmax(x, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch_geometric/nn/conv/sage_conv.py:123\u001b[0m, in \u001b[0;36mSAGEConv.forward\u001b[0;34m(self, x, edge_index, size)\u001b[0m\n\u001b[1;32m    120\u001b[0m     x \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlin(x[\u001b[38;5;241m0\u001b[39m])\u001b[38;5;241m.\u001b[39mrelu(), x[\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m    122\u001b[0m \u001b[38;5;66;03m# propagate_type: (x: OptPairTensor)\u001b[39;00m\n\u001b[0;32m--> 123\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpropagate\u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    124\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlin_l(out)\n\u001b[1;32m    126\u001b[0m x_r \u001b[38;5;241m=\u001b[39m x[\u001b[38;5;241m1\u001b[39m]\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py:352\u001b[0m, in \u001b[0;36mMessagePassing.propagate\u001b[0;34m(self, edge_index, size, **kwargs)\u001b[0m\n\u001b[1;32m    349\u001b[0m         aggr_kwargs \u001b[38;5;241m=\u001b[39m res[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(res, \u001b[38;5;28mtuple\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m res\n\u001b[1;32m    351\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maggrs) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 352\u001b[0m     out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maggregate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43maggr_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    353\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    354\u001b[0m     outs \u001b[38;5;241m=\u001b[39m []\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch_geometric/nn/conv/sage_conv.py:146\u001b[0m, in \u001b[0;36mSAGEConv.aggregate\u001b[0;34m(self, x, index, ptr, dim_size)\u001b[0m\n\u001b[1;32m    143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21maggregate\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: Tensor, index: Tensor, ptr: Optional[Tensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m    144\u001b[0m               dim_size: Optional[\u001b[38;5;28mint\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m    145\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maggr \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 146\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mscatter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdim_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    147\u001b[0m \u001b[43m                       \u001b[49m\u001b[43mreduce\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maggr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    149\u001b[0m     \u001b[38;5;66;03m# LSTM aggregation:\u001b[39;00m\n\u001b[1;32m    150\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m ptr \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mall(index[:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m index[\u001b[38;5;241m1\u001b[39m:]):\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch_scatter/scatter.py:160\u001b[0m, in \u001b[0;36mscatter\u001b[0;34m(src, index, dim, out, dim_size, reduce)\u001b[0m\n\u001b[1;32m    158\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m scatter_min(src, index, dim, out, dim_size)[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    159\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m reduce \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmax\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m--> 160\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mscatter_max\u001b[49m\u001b[43m(\u001b[49m\u001b[43msrc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim_size\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    161\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    162\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch_scatter/scatter.py:72\u001b[0m, in \u001b[0;36mscatter_max\u001b[0;34m(src, index, dim, out, dim_size)\u001b[0m\n\u001b[1;32m     68\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mscatter_max\u001b[39m(\n\u001b[1;32m     69\u001b[0m         src: torch\u001b[38;5;241m.\u001b[39mTensor, index: torch\u001b[38;5;241m.\u001b[39mTensor, dim: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m     70\u001b[0m         out: Optional[torch\u001b[38;5;241m.\u001b[39mTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m     71\u001b[0m         dim_size: Optional[\u001b[38;5;28mint\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[torch\u001b[38;5;241m.\u001b[39mTensor, torch\u001b[38;5;241m.\u001b[39mTensor]:\n\u001b[0;32m---> 72\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mops\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtorch_scatter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscatter_max\u001b[49m\u001b[43m(\u001b[49m\u001b[43msrc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim_size\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "best_val_acc = test_acc = 0\n",
    "for epoch in range(1,100):\n",
    "    train()\n",
    "    _, val_acc, tmp_test_acc = test()\n",
    "    if val_acc > best_val_acc:\n",
    "        best_val_acc = val_acc\n",
    "        test_acc = tmp_test_acc\n",
    "    log = 'Epoch: {:03d}, Val: {:.4f}, Test: {:.4f}'\n",
    "    \n",
    "    if epoch % 10 == 0:\n",
    "        print(log.format(epoch, best_val_acc, test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
